{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Online prediction with BigQuery ML\n",
    "\n",
    "ML.Predict in BigQuery ML is primarily meant for batch predictions. What if you want to build a web application to provide online predictions? Here, I show the basic Python code to do online prediction. You can wrap this code in AppEngine or other web framework/toolkit to provide scalable, fast, online predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "!pip install google-cloud   # Reset Session after installing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "PROJECT = 'cloud-training-demos'  # change as needed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a model\n",
    "\n",
    "Let's start by creating a simple prediction model to [predict arrival delays of aircraft](https://towardsdatascience.com/how-to-train-and-predict-regression-and-classification-ml-models-using-only-sql-using-bigquery-ml-f219b180b947). I'll use this to illustrate the process.\n",
    "\n",
    "First, if necessary, create the BigQuery dataset that will store the output of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%bash\n",
    "bq mk -d flights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, do a \"CREATE MODEL\". This will take about <b>5 minutes</b>."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "Done"
      ],
      "text/plain": [
       "QueryResultsTable job_1EM_XyCZV-9AGbXua8W72EfKMZ1g"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%bq query\n",
    "\n",
    "CREATE OR REPLACE MODEL flights.arrdelay\n",
    "OPTIONS\n",
    "  (model_type='linear_reg', input_label_cols=['arr_delay']) AS\n",
    "SELECT\n",
    "  arr_delay,\n",
    "  carrier,\n",
    "  origin,\n",
    "  dest,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance\n",
    "FROM\n",
    "  `cloud-training-demos.flights.tzcorr`\n",
    "WHERE\n",
    "  arr_delay IS NOT NULL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch prediction with model\n",
    "\n",
    "Once you have a trained model, batch prediction can be done within BigQuery itself.\n",
    "\n",
    "For example, to find the predicted arrival delays for a flight from DFW to LAX for a range of departure delays"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div class=\"bqtv\" id=\"4_153400967753\"><table><tr><th>predicted_arr_delay</th><th>carrier</th><th>origin</th><th>dest</th><th>dep_delay</th><th>taxi_out</th><th>distance</th></tr><tr><td>-8.31085723761</td><td>AA</td><td>DFW</td><td>LAX</td><td>-3</td><td>18</td><td>1235</td></tr><tr><td>-7.32016373789</td><td>AA</td><td>DFW</td><td>LAX</td><td>-2</td><td>18</td><td>1235</td></tr><tr><td>-6.32947023816</td><td>AA</td><td>DFW</td><td>LAX</td><td>-1</td><td>18</td><td>1235</td></tr><tr><td>-5.33877673844</td><td>AA</td><td>DFW</td><td>LAX</td><td>0</td><td>18</td><td>1235</td></tr><tr><td>-4.34808323871</td><td>AA</td><td>DFW</td><td>LAX</td><td>1</td><td>18</td><td>1235</td></tr><tr><td>-3.35738973899</td><td>AA</td><td>DFW</td><td>LAX</td><td>2</td><td>18</td><td>1235</td></tr><tr><td>-2.36669623926</td><td>AA</td><td>DFW</td><td>LAX</td><td>3</td><td>18</td><td>1235</td></tr><tr><td>-1.37600273954</td><td>AA</td><td>DFW</td><td>LAX</td><td>4</td><td>18</td><td>1235</td></tr><tr><td>-0.385309239811</td><td>AA</td><td>DFW</td><td>LAX</td><td>5</td><td>18</td><td>1235</td></tr><tr><td>0.605384259914</td><td>AA</td><td>DFW</td><td>LAX</td><td>6</td><td>18</td><td>1235</td></tr><tr><td>1.59607775964</td><td>AA</td><td>DFW</td><td>LAX</td><td>7</td><td>18</td><td>1235</td></tr><tr><td>2.58677125936</td><td>AA</td><td>DFW</td><td>LAX</td><td>8</td><td>18</td><td>1235</td></tr><tr><td>3.57746475909</td><td>AA</td><td>DFW</td><td>LAX</td><td>9</td><td>18</td><td>1235</td></tr><tr><td>4.56815825881</td><td>AA</td><td>DFW</td><td>LAX</td><td>10</td><td>18</td><td>1235</td></tr></table></div>\n",
       "    <br />(rows: 14, time: 1.2s,    16KB processed, job: job_WVW0Xk2F-GX7NphFVwFZXwenWOIM)<br />\n",
       "    <script src=\"/static/components/requirejs/require.js\"></script>\n",
       "    <script>\n",
       "      require.config({\n",
       "        paths: {\n",
       "          base: '/static/base',\n",
       "          d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.13/d3',\n",
       "          plotly: 'https://cdn.plot.ly/plotly-1.5.1.min.js?noext',\n",
       "          jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min'\n",
       "        },\n",
       "        map: {\n",
       "          '*': {\n",
       "            datalab: 'nbextensions/gcpdatalab'\n",
       "          }\n",
       "        },\n",
       "        shim: {\n",
       "          plotly: {\n",
       "            deps: ['d3', 'jquery'],\n",
       "            exports: 'plotly'\n",
       "          }\n",
       "        }\n",
       "      });\n",
       "\n",
       "      require(['datalab/charting', 'datalab/element!4_153400967753', 'base/js/events',\n",
       "          'datalab/style!/nbextensions/gcpdatalab/charting.css'],\n",
       "        function(charts, dom, events) {\n",
       "          charts.render('gcharts', dom, events, 'table', [], {\"rows\": [{\"c\": [{\"v\": -8.310857237611458}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": -3}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": -7.320163737886418}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": -2}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": -6.329470238161379}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": -1}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": -5.338776738436341}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": 0}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": -4.348083238711302}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": 1}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": -3.3573897389862637}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": 2}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": -2.3666962392612243}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": 3}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": -1.3760027395361851}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": 4}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": -0.3853092398111464}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": 5}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": 0.6053842599138926}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": 6}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": 1.5960777596389315}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": 7}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": 2.58677125936397}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": 8}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": 3.5774647590890094}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": 9}, {\"v\": 18}, {\"v\": 1235}]}, {\"c\": [{\"v\": 4.568158258814048}, {\"v\": \"AA\"}, {\"v\": \"DFW\"}, {\"v\": \"LAX\"}, {\"v\": 10}, {\"v\": 18}, {\"v\": 1235}]}], \"cols\": [{\"type\": \"number\", \"id\": 
\"predicted_arr_delay\", \"label\": \"predicted_arr_delay\"}, {\"type\": \"string\", \"id\": \"carrier\", \"label\": \"carrier\"}, {\"type\": \"string\", \"id\": \"origin\", \"label\": \"origin\"}, {\"type\": \"string\", \"id\": \"dest\", \"label\": \"dest\"}, {\"type\": \"number\", \"id\": \"dep_delay\", \"label\": \"dep_delay\"}, {\"type\": \"number\", \"id\": \"taxi_out\", \"label\": \"taxi_out\"}, {\"type\": \"number\", \"id\": \"distance\", \"label\": \"distance\"}]},\n",
       "            {\n",
       "              pageSize: 25,\n",
       "              cssClassNames:  {\n",
       "                tableRow: 'gchart-table-row',\n",
       "                headerRow: 'gchart-table-headerrow',\n",
       "                oddTableRow: 'gchart-table-oddrow',\n",
       "                selectedTableRow: 'gchart-table-selectedrow',\n",
       "                hoverTableRow: 'gchart-table-hoverrow',\n",
       "                tableCell: 'gchart-table-cell',\n",
       "                headerCell: 'gchart-table-headercell',\n",
       "                rowNumberCell: 'gchart-table-rownumcell'\n",
       "              }\n",
       "            },\n",
       "            {source_index: 3, fields: 'predicted_arr_delay,carrier,origin,dest,dep_delay,taxi_out,distance'},\n",
       "            0,\n",
       "            14);\n",
       "        }\n",
       "      );\n",
       "    </script>\n",
       "  "
      ],
      "text/plain": [
       "QueryResultsTable job_WVW0Xk2F-GX7NphFVwFZXwenWOIM"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%bq query\n",
    "SELECT * FROM ml.PREDICT(MODEL flights.arrdelay, (\n",
    "SELECT \n",
    "  'AA' as carrier,\n",
    "  'DFW' as origin,\n",
    "  'LAX' as dest,\n",
    "  dep_delay,\n",
    "  18 as taxi_out,\n",
    "  1235 as distance\n",
    "FROM\n",
    "  UNNEST(GENERATE_ARRAY(-3, 10)) as dep_delay\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Online prediction in Python\n",
    "\n",
    "The batch prediction technique above can not be used for online prediction though.  Typical BigQuery queries  have a latency of 1-2 seconds and that is too high for a web application.\n",
    "\n",
    "For online prediction, it is better to grab the weights and do the computation yourself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div class=\"bqtv\" id=\"8_153404735236\"><table><tr><th>input</th><th>input_weight</th></tr><tr><td>carrier</td><td>&nbsp;</td></tr><tr><td>origin</td><td>&nbsp;</td></tr><tr><td>dest</td><td>&nbsp;</td></tr><tr><td>dep_delay</td><td>36.5569545237</td></tr><tr><td>taxi_out</td><td>8.15557957221</td></tr><tr><td>distance</td><td>-1.88324519311</td></tr><tr><td>__INTERCEPT__</td><td>1.09017737502</td></tr></table></div>\n",
       "    <br />(rows: 7, time: 1.3s,    16KB processed, job: job_5YCmyG8UY5qfA8aNa9hy_RQBgzI1)<br />\n",
       "    <script src=\"/static/components/requirejs/require.js\"></script>\n",
       "    <script>\n",
       "      require.config({\n",
       "        paths: {\n",
       "          base: '/static/base',\n",
       "          d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.13/d3',\n",
       "          plotly: 'https://cdn.plot.ly/plotly-1.5.1.min.js?noext',\n",
       "          jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min'\n",
       "        },\n",
       "        map: {\n",
       "          '*': {\n",
       "            datalab: 'nbextensions/gcpdatalab'\n",
       "          }\n",
       "        },\n",
       "        shim: {\n",
       "          plotly: {\n",
       "            deps: ['d3', 'jquery'],\n",
       "            exports: 'plotly'\n",
       "          }\n",
       "        }\n",
       "      });\n",
       "\n",
       "      require(['datalab/charting', 'datalab/element!8_153404735236', 'base/js/events',\n",
       "          'datalab/style!/nbextensions/gcpdatalab/charting.css'],\n",
       "        function(charts, dom, events) {\n",
       "          charts.render('gcharts', dom, events, 'table', [], {\"rows\": [{\"c\": [{\"v\": \"carrier\"}, {\"v\": null}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": null}]}, {\"c\": [{\"v\": \"dest\"}, {\"v\": null}]}, {\"c\": [{\"v\": \"dep_delay\"}, {\"v\": 36.55695452367766}]}, {\"c\": [{\"v\": \"taxi_out\"}, {\"v\": 8.15557957221127}]}, {\"c\": [{\"v\": \"distance\"}, {\"v\": -1.8832451931138876}]}, {\"c\": [{\"v\": \"__INTERCEPT__\"}, {\"v\": 1.0901773750231667}]}], \"cols\": [{\"type\": \"string\", \"id\": \"input\", \"label\": \"input\"}, {\"type\": \"number\", \"id\": \"input_weight\", \"label\": \"input_weight\"}]},\n",
       "            {\n",
       "              pageSize: 25,\n",
       "              cssClassNames:  {\n",
       "                tableRow: 'gchart-table-row',\n",
       "                headerRow: 'gchart-table-headerrow',\n",
       "                oddTableRow: 'gchart-table-oddrow',\n",
       "                selectedTableRow: 'gchart-table-selectedrow',\n",
       "                hoverTableRow: 'gchart-table-hoverrow',\n",
       "                tableCell: 'gchart-table-cell',\n",
       "                headerCell: 'gchart-table-headercell',\n",
       "                rowNumberCell: 'gchart-table-rownumcell'\n",
       "              }\n",
       "            },\n",
       "            {source_index: 7, fields: 'input,input_weight'},\n",
       "            0,\n",
       "            7);\n",
       "        }\n",
       "      );\n",
       "    </script>\n",
       "  "
      ],
      "text/plain": [
       "QueryResultsTable job_5YCmyG8UY5qfA8aNa9hy_RQBgzI1"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%bq query\n",
    "SELECT\n",
    "  processed_input AS input,\n",
    "  model.weight AS input_weight\n",
    "FROM\n",
    "  ml.WEIGHTS(MODEL flights.arrdelay) AS model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div class=\"bqtv\" id=\"6_153404722666\"><table><tr><th>input</th><th>input_weight</th><th>category_name</th><th>category_weight</th></tr><tr><td>carrier</td><td>&nbsp;</td><td>B6</td><td>2.41360895323</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>EV</td><td>2.27889474393</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>AA</td><td>-0.0548843604154</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>US</td><td>1.82939081237</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>DL</td><td>-1.28516116856</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>F9</td><td>5.76556618145</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>WN</td><td>2.18336049823</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>HA</td><td>6.17381375535</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>VX</td><td>5.05422774942</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>AS</td><td>2.022529139</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>UA</td><td>-1.9492217572</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>OO</td><td>1.11110679975</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>NK</td><td>7.09579009395</td></tr><tr><td>carrier</td><td>&nbsp;</td><td>MQ</td><td>1.19620456586</td></tr><tr><td>origin</td><td>&nbsp;</td><td>WRG</td><td>8.99721250183</td></tr><tr><td>origin</td><td>&nbsp;</td><td>MSY</td><td>4.96015143256</td></tr><tr><td>origin</td><td>&nbsp;</td><td>AMA</td><td>6.20849983281</td></tr><tr><td>origin</td><td>&nbsp;</td><td>SAF</td><td>4.19530797642</td></tr><tr><td>origin</td><td>&nbsp;</td><td>GUC</td><td>3.60929742614</td></tr><tr><td>origin</td><td>&nbsp;</td><td>MKG</td><td>2.23621423776</td></tr><tr><td>origin</td><td>&nbsp;</td><td>CHO</td><td>3.63391263959</td></tr><tr><td>origin</td><td>&nbsp;</td><td>ESC</td><td>1.01325898826</td></tr><tr><td>origin</td><td>&nbsp;</td><td>MCI</td><td>4.16187598831</td></tr><tr><td>origin</td><td>&nbsp;</td><td>GSP</td><td>3.06866659933</td></tr><tr><td>origin</td><td>&nbsp;</td><td>GJT</td><td>3.6366332815</td></tr></table></div>\n",
       "    <br />(rows: 658, time: 1.1s,    16KB processed, job: job_h3LTu2B6REiDJWXHslEMZ9MyIjH_)<br />\n",
       "    <script src=\"/static/components/requirejs/require.js\"></script>\n",
       "    <script>\n",
       "      require.config({\n",
       "        paths: {\n",
       "          base: '/static/base',\n",
       "          d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.13/d3',\n",
       "          plotly: 'https://cdn.plot.ly/plotly-1.5.1.min.js?noext',\n",
       "          jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min'\n",
       "        },\n",
       "        map: {\n",
       "          '*': {\n",
       "            datalab: 'nbextensions/gcpdatalab'\n",
       "          }\n",
       "        },\n",
       "        shim: {\n",
       "          plotly: {\n",
       "            deps: ['d3', 'jquery'],\n",
       "            exports: 'plotly'\n",
       "          }\n",
       "        }\n",
       "      });\n",
       "\n",
       "      require(['datalab/charting', 'datalab/element!6_153404722666', 'base/js/events',\n",
       "          'datalab/style!/nbextensions/gcpdatalab/charting.css'],\n",
       "        function(charts, dom, events) {\n",
       "          charts.render('gcharts', dom, events, 'paged_table', [], {\"rows\": [{\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"B6\"}, {\"v\": 2.4136089532266967}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"EV\"}, {\"v\": 2.278894743926057}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"AA\"}, {\"v\": -0.05488436041544667}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"US\"}, {\"v\": 1.8293908123702811}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"DL\"}, {\"v\": -1.2851611685565674}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"F9\"}, {\"v\": 5.765566181448997}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"WN\"}, {\"v\": 2.1833604982347374}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"HA\"}, {\"v\": 6.173813755352824}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"VX\"}, {\"v\": 5.054227749424656}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"AS\"}, {\"v\": 2.022529138999248}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"UA\"}, {\"v\": -1.9492217572004094}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"OO\"}, {\"v\": 1.1111067997490025}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"NK\"}, {\"v\": 7.09579009395471}]}, {\"c\": [{\"v\": \"carrier\"}, {\"v\": null}, {\"v\": \"MQ\"}, {\"v\": 1.1962045658633098}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": null}, {\"v\": \"WRG\"}, {\"v\": 8.99721250183298}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": null}, {\"v\": \"MSY\"}, {\"v\": 4.960151432563211}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": null}, {\"v\": \"AMA\"}, {\"v\": 6.208499832813039}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": null}, {\"v\": \"SAF\"}, {\"v\": 4.1953079764207715}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": null}, {\"v\": \"GUC\"}, {\"v\": 3.609297426138213}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": null}, {\"v\": \"MKG\"}, {\"v\": 2.236214237755088}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": 
null}, {\"v\": \"CHO\"}, {\"v\": 3.633912639589261}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": null}, {\"v\": \"ESC\"}, {\"v\": 1.013258988264886}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": null}, {\"v\": \"MCI\"}, {\"v\": 4.161875988312131}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": null}, {\"v\": \"GSP\"}, {\"v\": 3.068666599332962}]}, {\"c\": [{\"v\": \"origin\"}, {\"v\": null}, {\"v\": \"GJT\"}, {\"v\": 3.6366332815027502}]}], \"cols\": [{\"type\": \"string\", \"id\": \"input\", \"label\": \"input\"}, {\"type\": \"number\", \"id\": \"input_weight\", \"label\": \"input_weight\"}, {\"type\": \"string\", \"id\": \"category_name\", \"label\": \"category_name\"}, {\"type\": \"number\", \"id\": \"category_weight\", \"label\": \"category_weight\"}]},\n",
       "            {\n",
       "              pageSize: 25,\n",
       "              cssClassNames:  {\n",
       "                tableRow: 'gchart-table-row',\n",
       "                headerRow: 'gchart-table-headerrow',\n",
       "                oddTableRow: 'gchart-table-oddrow',\n",
       "                selectedTableRow: 'gchart-table-selectedrow',\n",
       "                hoverTableRow: 'gchart-table-hoverrow',\n",
       "                tableCell: 'gchart-table-cell',\n",
       "                headerCell: 'gchart-table-headercell',\n",
       "                rowNumberCell: 'gchart-table-rownumcell'\n",
       "              }\n",
       "            },\n",
       "            {source_index: 5, fields: 'input,input_weight,category_name,category_weight'},\n",
       "            0,\n",
       "            658);\n",
       "        }\n",
       "      );\n",
       "    </script>\n",
       "  "
      ],
      "text/plain": [
       "QueryResultsTable job_h3LTu2B6REiDJWXHslEMZ9MyIjH_"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%bq query\n",
    "SELECT\n",
    "  processed_input AS input,\n",
    "  model.weight AS input_weight,\n",
    "  category.category AS category_name,\n",
    "  category.weight AS category_weight\n",
    "FROM\n",
    "  ml.WEIGHTS(MODEL flights.arrdelay) AS model,\n",
    "  UNNEST(category_weights) AS category"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's how to do that in Python. \n",
    " \n",
    "p.s. I'm assuming that you are in an environment where you are already authenticated with Google Cloud. If not, see [this article on how to access BigQuery using private keys](https://towardsdatascience.com/how-to-enable-pandas-to-access-bigquery-from-a-service-account-205a216f0f68)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def query_to_dataframe(query):\n",
    "  import pandas as pd\n",
    "  import pkgutil\n",
    "  privatekey = None # pkgutil.get_data(KEYDIR, 'privatekey.json')\n",
    "  return pd.read_gbq(query,\n",
    "                     project_id=PROJECT,\n",
    "                     dialect='standard',\n",
    "                     private_key=privatekey)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You need to pull 3 pieces of information:\n",
    "* The weights for each of your numeric columns\n",
    "* The scaling for each of your numeric columns\n",
    "* The vocabulary and weights for each of your categorical columns\n",
    "\n",
    "I pull them using three separate BigQuery queries below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requesting query... ok.\n",
      "Job ID: 4f89a94d-11a1-4cfa-bbb2-f046ffe63df1\n",
      "Query running...\n",
      "Query done.\n",
      "Cache hit.\n",
      "\n",
      "Retrieving results...\n",
      "Got 7 rows.\n",
      "\n",
      "Total time taken 1.02 s.\n",
      "Finished at 2018-08-12 04:20:46.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>input</th>\n",
       "      <th>input_weight</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>dep_delay</td>\n",
       "      <td>36.556955</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>taxi_out</td>\n",
       "      <td>8.155580</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>distance</td>\n",
       "      <td>-1.883245</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>__INTERCEPT__</td>\n",
       "      <td>1.090177</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           input  input_weight\n",
       "3      dep_delay     36.556955\n",
       "4       taxi_out      8.155580\n",
       "5       distance     -1.883245\n",
       "6  __INTERCEPT__      1.090177"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "numeric_query = \"\"\"\n",
    "SELECT\n",
    "  processed_input AS input,\n",
    "  model.weight AS input_weight\n",
    "FROM\n",
    "  ml.WEIGHTS(MODEL flights.arrdelay) AS model\n",
    "\"\"\"\n",
    "numeric_weights = query_to_dataframe(numeric_query).dropna()\n",
    "numeric_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requesting query... ok.\n",
      "Job ID: a3d9e2fb-f940-48ba-a5aa-2e7645d82db1\n",
      "Query running...\n",
      "Query done.\n",
      "Processed: 0.0 B Billed: 0.0 B\n",
      "Standard price: $0.00 USD\n",
      "\n",
      "Retrieving results...\n",
      "Got 6 rows.\n",
      "\n",
      "Total time taken 2.63 s.\n",
      "Finished at 2018-08-12 05:00:15.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>input</th>\n",
       "      <th>min</th>\n",
       "      <th>max</th>\n",
       "      <th>mean</th>\n",
       "      <th>stddev</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>dep_delay</td>\n",
       "      <td>-82.0</td>\n",
       "      <td>1988.0</td>\n",
       "      <td>9.169095</td>\n",
       "      <td>36.900368</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>taxi_out</td>\n",
       "      <td>1.0</td>\n",
       "      <td>225.0</td>\n",
       "      <td>16.099878</td>\n",
       "      <td>8.901454</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>distance</td>\n",
       "      <td>31.0</td>\n",
       "      <td>4983.0</td>\n",
       "      <td>825.795621</td>\n",
       "      <td>608.756947</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       input   min     max        mean      stddev\n",
       "3  dep_delay -82.0  1988.0    9.169095   36.900368\n",
       "4   taxi_out   1.0   225.0   16.099878    8.901454\n",
       "5   distance  31.0  4983.0  825.795621  608.756947"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scaling_query = \"\"\"\n",
    "SELECT\n",
    "  input, min, max, mean, stddev\n",
    "FROM\n",
    "  ml.FEATURE_INFO(MODEL flights.arrdelay) AS model\n",
    "\"\"\"\n",
    "scaling_df = query_to_dataframe(scaling_query).dropna()\n",
    "scaling_df"
   ]
  },
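  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a quick sanity check on the scaling. It assumes each numeric input is standardized as $(x - \\mu)/\\sigma$ before its weight $w$ is applied, so that one extra minute of departure delay changes the prediction by $w/\\sigma$. Using the `dep_delay` weight and standard deviation shown above:\n",
    "\n",
    "$$\\frac{w}{\\sigma} = \\frac{36.556955}{36.900368} \\approx 0.9907$$\n",
    "\n",
    "This matches the step between consecutive rows of the batch prediction table, for example $-7.3202 - (-8.3109) \\approx 0.9907$, which suggests that mean/stddev scaling (rather than min/max scaling) is the transform to reproduce."
   ]
  },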
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requesting query... ok.\n",
      "Job ID: b662c5fd-32af-4624-88db-d7aeed97300f\n",
      "Query running...\n",
      "Query done.\n",
      "Cache hit.\n",
      "\n",
      "Retrieving results...\n",
      "Got 658 rows.\n",
      "\n",
      "Total time taken 1.03 s.\n",
      "Finished at 2018-08-12 04:17:36.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>input</th>\n",
       "      <th>input_weight</th>\n",
       "      <th>category_name</th>\n",
       "      <th>category_weight</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>carrier</td>\n",
       "      <td>NaN</td>\n",
       "      <td>B6</td>\n",
       "      <td>2.413609</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>carrier</td>\n",
       "      <td>NaN</td>\n",
       "      <td>EV</td>\n",
       "      <td>2.278895</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>carrier</td>\n",
       "      <td>NaN</td>\n",
       "      <td>AA</td>\n",
       "      <td>-0.054884</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>carrier</td>\n",
       "      <td>NaN</td>\n",
       "      <td>US</td>\n",
       "      <td>1.829391</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>carrier</td>\n",
       "      <td>NaN</td>\n",
       "      <td>DL</td>\n",
       "      <td>-1.285161</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     input  input_weight category_name  category_weight\n",
       "0  carrier           NaN            B6         2.413609\n",
       "1  carrier           NaN            EV         2.278895\n",
       "2  carrier           NaN            AA        -0.054884\n",
       "3  carrier           NaN            US         1.829391\n",
       "4  carrier           NaN            DL        -1.285161"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "categorical_query = \"\"\"\n",
    "SELECT\n",
    "  processed_input AS input,\n",
    "  model.weight AS input_weight,\n",
    "  category.category AS category_name,\n",
    "  category.weight AS category_weight\n",
    "FROM\n",
    "  ml.WEIGHTS(MODEL flights.arrdelay) AS model,\n",
    "  UNNEST(category_weights) AS category\n",
    "\"\"\"\n",
    "categorical_weights = query_to_dataframe(categorical_query)\n",
    "categorical_weights.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the three pieces of information in place (the numeric weights, the scaling statistics, and the categorical weights), you can simply do the math for linear regression:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def compute_prediction(rowdict, numeric_weights, scaling_df, categorical_weights):\n",
    "  # Linear regression: intercept + sum of (weight * standardized numeric input)\n",
    "  # + the weight of the category value for each categorical input.\n",
    "  input_values = rowdict\n",
    "  # numeric inputs\n",
    "  pred = 0\n",
    "  for column_name in numeric_weights['input'].unique():\n",
    "    wt = numeric_weights[ numeric_weights['input'] == column_name ]['input_weight'].values[0]\n",
    "    if column_name != '__INTERCEPT__':\n",
    "      # standardize the raw value with the mean and stddev stored in scaling_df\n",
    "      meanv = scaling_df[ scaling_df['input'] == column_name ]['mean'].values[0]\n",
    "      stddev = scaling_df[ scaling_df['input'] == column_name ]['stddev'].values[0]\n",
    "      scaled_value = (input_values[column_name] - meanv)/stddev\n",
    "    else:\n",
    "      # the intercept has no input value, so its weight is added as-is\n",
    "      scaled_value = 1.0\n",
    "    contrib = wt * scaled_value\n",
    "    print('col={} wt={} scaled_value={} contrib={}'.format(column_name, wt, scaled_value, contrib))\n",
    "    pred = pred + contrib\n",
    "  # categorical inputs: look up the weight of this row's category value\n",
    "  # (.values[0] raises IndexError for a category with no weight in categorical_weights)\n",
    "  for column_name in categorical_weights['input'].unique():\n",
    "    category_weights = categorical_weights[ categorical_weights['input'] == column_name ]\n",
    "    wt = category_weights[ category_weights['category_name'] == input_values[column_name] ]['category_weight'].values[0]\n",
    "    print('col={} wt={} value={} contrib={}'.format(column_name, wt, input_values[column_name], wt))\n",
    "    pred = pred + wt\n",
    "  return pred"
   ]
  },
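  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In symbols, the function above computes the following sum. The notation is introduced here for illustration and does not come from BigQuery ML: $b$ is the `__INTERCEPT__` weight, $w_j$ is the weight of numeric input $x_j$, $\\mu_j$ and $\\sigma_j$ are the mean and standard deviation of that input from `scaling_df`, and $c_k(v_k)$ is the weight of the category value $v_k$ for categorical input $k$.\n",
    "\n",
    "$$\\hat{y} = b + \\sum_j w_j \\, \\frac{x_j - \\mu_j}{\\sigma_j} + \\sum_k c_k(v_k)$$"
   ]
  },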
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is an example of the prediction code in action:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "col=dep_delay wt=36.5569545237 scaled_value=-0.329782492822 contrib=-12.0558435928\n",
      "col=taxi_out wt=8.15557957221 scaled_value=0.213461991601 contrib=1.74090625815\n",
      "col=distance wt=-1.88324519311 scaled_value=0.672196648431 contrib=-1.26591110699\n",
      "col=__INTERCEPT__ wt=1.09017737502 scaled_value=1.0 contrib=1.09017737502\n",
      "col=carrier wt=-0.0548843604154 value=AA contrib=-0.0548843604154\n",
      "col=origin wt=0.966535564037 value=DFW contrib=0.966535564037\n",
      "col=dest wt=1.26816262538 value=LAX contrib=1.26816262538\n",
      "-8.310857237611458\n"
     ]
    }
   ],
   "source": [
    "rowdict = {\n",
    "  'carrier' : 'AA',\n",
    "  'origin': 'DFW',\n",
    "  'dest': 'LAX',\n",
    "  'dep_delay': -3,\n",
    "  'taxi_out': 18,\n",
    "  'distance': 1235\n",
    "}\n",
    "print(compute_prediction(rowdict, numeric_weights,  scaling_df, categorical_weights))"
   ]
  },
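  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check, adding up the seven contributions printed above (rounded to three decimals) reproduces the returned value:\n",
    "\n",
    "$$-12.056 + 1.741 - 1.266 + 1.090 - 0.055 + 0.967 + 1.268 \\approx -8.311$$"
   ]
  },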
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result matches the batch prediction value, which confirms that the computation is correct.\n",
    "\n",
    "Now, all that remains is to wrap this code in a web application to serve online predictions.\n"
   ]
  },
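  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The wrapping step might look something like the minimal sketch below, which uses only the Python standard library (WSGI) rather than any particular web framework. The names `make_app` and `predict_fn` and the `predicted_arr_delay` response key are hypothetical, chosen for illustration. The sketch assumes the client POSTs one row as a JSON object.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "def make_app(predict_fn):\n",
    "  # predict_fn (hypothetical) takes a dict of input values and returns a number\n",
    "  def app(environ, start_response):\n",
    "    try:\n",
    "      size = int(environ.get('CONTENT_LENGTH') or 0)\n",
    "      row = json.loads(environ['wsgi.input'].read(size).decode('utf-8'))\n",
    "      body = {'predicted_arr_delay': float(predict_fn(row))}\n",
    "      status = '200 OK'\n",
    "    except (KeyError, IndexError, TypeError, ValueError) as err:\n",
    "      # missing column, unknown category, or malformed JSON\n",
    "      body = {'error': repr(err)}\n",
    "      status = '400 Bad Request'\n",
    "    payload = json.dumps(body).encode('utf-8')\n",
    "    start_response(status, [('Content-Type', 'application/json'),\n",
    "                            ('Content-Length', str(len(payload)))])\n",
    "    return [payload]\n",
    "  return app\n",
    "```\n",
    "\n",
    "In this notebook, `predict_fn` could be `lambda row: compute_prediction(row, numeric_weights, scaling_df, categorical_weights)`, and the app could be served locally with `wsgiref.simple_server.make_server('', 8080, app).serve_forever()`. Because the three DataFrames are loaded once, each request should involve only local arithmetic and no BigQuery call."
   ]
  },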
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2017 Google Inc. Licensed under the Apache License, Version 2.0 (the \\\"License\\\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \\\"AS IS\\\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
